package cool.mtc.core.util;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.math.BigDecimal;
import java.sql.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Small JDBC helper that runs a parameterized query and maps every row either to a single
 * scalar value (read from the first column) or to an entity object populated via reflection.
 *
 * @author 明河
 */
@Slf4j
public abstract class JdbcUtil {
    /**
     * Target types that are read directly from the first column of each row instead of being
     * populated field-by-field. Any other target type is treated as an entity class.
     */
    private static final Set<Class<?>> SCALAR_TYPES = new HashSet<>(Arrays.<Class<?>>asList(
            String.class, Integer.class, Long.class, Short.class, Byte.class,
            Double.class, Float.class, Boolean.class, BigDecimal.class));

    /**
     * Opens a connection for the given database type, executes the query and maps the result.
     * The connection is closed before this method returns.
     *
     * @param dbType      database type (supplies the driver class and URL template)
     * @param ip          database host
     * @param port        database port
     * @param db          database name
     * @param username    user name
     * @param password    user password
     * @param sql         SQL statement using {@code ?} placeholders
     * @param params      positional parameter values; may be {@code null} or empty
     * @param targetClass type each row is mapped to
     * @return the mapped rows, never {@code null}
     * @throws SQLException           if connecting or executing the query fails
     * @throws ClassNotFoundException if the JDBC driver class cannot be loaded
     */
    public static <T> List<T> exec(DbTypeEnum dbType, String ip, String port, String db, String username, String password, String sql, Object[] params, Class<T> targetClass) throws SQLException, ClassNotFoundException {
        Connection connection = getConnection(dbType, ip, port, db, username, password);
        return exec(connection, sql, params, targetClass);
    }

    /**
     * Opens a connection to {@code dbUrl}, executes the query and maps the result.
     * The connection is closed before this method returns.
     *
     * @param dbUrl       full JDBC URL
     * @param username    user name
     * @param password    user password
     * @param sql         SQL statement using {@code ?} placeholders
     * @param params      positional parameter values; may be {@code null} or empty
     * @param targetClass type each row is mapped to
     * @return the mapped rows, never {@code null}
     * @throws SQLException if connecting or executing the query fails
     */
    public static <T> List<T> exec(String dbUrl, String username, String password, String sql, Object[] params, Class<T> targetClass) throws SQLException {
        Connection connection = getConnection(dbUrl, username, password);
        return exec(connection, sql, params, targetClass);
    }

    /**
     * Executes the query on the given connection and maps the result.
     *
     * <p><b>Note:</b> this method takes ownership of {@code connection}. It is closed, together with
     * the statement and result set, whether the query succeeds or fails.
     *
     * @param connection  open connection; closed by this method
     * @param sql         SQL statement using {@code ?} placeholders
     * @param params      positional parameter values; may be {@code null} or empty
     * @param targetClass type each row is mapped to
     * @return the mapped rows, never {@code null}
     * @throws SQLException if preparing or executing the query fails
     */
    public static <T> List<T> exec(Connection connection, String sql, Object[] params, Class<T> targetClass) throws SQLException {
        PreparedStatement statement = null;
        ResultSet resultSet = null;
        try {
            statement = connection.prepareStatement(sql);
            if (params != null) {
                for (int i = 0; i < params.length; i++) {
                    // JDBC parameter indexes are 1-based
                    statement.setObject(i + 1, params[i]);
                }
            }
            resultSet = statement.executeQuery();
            return transResultSetToObject(resultSet, targetClass);
        } finally {
            // Release resources even when an exception is propagating. The helpers log (rather
            // than throw) close failures so they cannot mask the original error.
            close(resultSet);
            close(statement);
            close(connection);
        }
    }

    /**
     * Opens a connection using the database type's default host and port.
     *
     * @throws SQLException           if the connection cannot be established
     * @throws ClassNotFoundException if the JDBC driver class cannot be loaded
     */
    public static Connection getConnection(DbTypeEnum dbType, String db, String username, String password) throws SQLException, ClassNotFoundException {
        return getConnection(dbType, dbType.getDefaultIp(), dbType.getDefaultPort(), db, username, password);
    }

    /**
     * Loads the driver for {@code dbType} and opens a connection to the given host, port and database.
     *
     * @throws ClassNotFoundException if the JDBC driver class cannot be loaded
     * @throws SQLException           if the connection cannot be established
     */
    public static Connection getConnection(DbTypeEnum dbType, String ip, String port, String db, String username, String password) throws ClassNotFoundException, SQLException {
        Class.forName(dbType.getDriverClassName());
        return getConnection(getDBUrl(dbType, ip, port, db), username, password);
    }

    /**
     * Opens a connection to the given JDBC URL via {@link DriverManager}.
     *
     * @throws SQLException if the connection cannot be established
     */
    public static Connection getConnection(String dbUrl, String username, String password) throws SQLException {
        return DriverManager.getConnection(dbUrl, username, password);
    }

    /**
     * Closes the connection, ignoring {@code null} and logging (not throwing) any failure.
     */
    public static void close(Connection connection) {
        if (null == connection) {
            return;
        }
        try {
            connection.close();
        } catch (SQLException ex) {
            log.error("关闭Connection失败", ex);
        }
    }

    /**
     * Closes the statement, ignoring {@code null} and logging (not throwing) any failure.
     */
    public static void close(Statement statement) {
        if (null == statement) {
            return;
        }
        try {
            statement.close();
        } catch (SQLException ex) {
            log.error("关闭Statement失败", ex);
        }
    }

    /**
     * Closes the result set, ignoring {@code null} and logging (not throwing) any failure.
     */
    public static void close(ResultSet resultSet) {
        if (null == resultSet) {
            return;
        }
        try {
            resultSet.close();
        } catch (SQLException ex) {
            log.error("关闭ResultSet失败", ex);
        }
    }

    /**
     * Maps every row of {@code resultSet} to an instance of {@code clazz}.
     *
     * <p>Scalar target types are read from column 1. Any other type is instantiated through its
     * no-arg constructor. Each non-static field declared directly on that type is then filled
     * from the column with the same name. Fields whose column cannot be read are logged and
     * left untouched. Mapping stops (keeping the rows converted so far) if the type cannot be
     * instantiated or a field cannot be written.
     */
    private static <T> List<T> transResultSetToObject(ResultSet resultSet, Class<T> clazz) throws SQLException {
        List<T> list = new ArrayList<>();
        if (null == resultSet) {
            return list;
        }
        if (SCALAR_TYPES.contains(clazz)) {
            while (resultSet.next()) {
                list.add(resultSet.getObject(1, clazz));
            }
            return list;
        }
        // Resolve the writable instance fields once instead of once per row.
        List<Field> fields = new ArrayList<>();
        for (Field field : clazz.getDeclaredFields()) {
            if (Modifier.isStatic(field.getModifiers()) || field.isSynthetic()) {
                continue;
            }
            field.setAccessible(true);
            fields.add(field);
        }
        while (resultSet.next()) {
            try {
                T t = clazz.getDeclaredConstructor().newInstance();
                for (Field field : fields) {
                    Object value;
                    try {
                        value = resultSet.getObject(field.getName(), field.getType());
                    } catch (SQLException ex) {
                        // Missing column or unsupported conversion: skip this field only.
                        log.warn("从查询结果中读取属性值[{}]失败。", field.getName());
                        continue;
                    }
                    field.set(t, value);
                }
                list.add(t);
            } catch (ReflectiveOperationException ex) {
                log.error("查询结果转换失败。", ex);
                break;
            }
        }
        return list;
    }

    /**
     * Builds the JDBC URL for {@code dbType} by filling the {@code {ip}}, {@code {port}} and
     * {@code {db}} placeholders of its URL template.
     */
    private static String getDBUrl(DbTypeEnum dbType, String ip, String port, String db) {
        return dbType.getUrlTemplate().replace("{ip}", ip)
                .replace("{port}", port)
                .replace("{db}", db);
    }

    /**
     * Supported database types: driver class, URL template and default host/port.
     */
    @Getter
    @RequiredArgsConstructor
    public enum DbTypeEnum {

        // NOTE(review): "com.mysql.jdbc.Driver" is the legacy Connector/J 5.x name; 8.x uses
        // "com.mysql.cj.jdbc.Driver". Confirm which driver version the project ships.
        MYSQL("com.mysql.jdbc.Driver", "jdbc:mysql://{ip}:{port}/{db}?useSSL=false", "localhost", "3306"),
        ;

        /** Fully-qualified JDBC driver class, loaded via {@code Class.forName}. */
        private final String driverClassName;
        /** URL template containing {@code {ip}}, {@code {port}} and {@code {db}} placeholders. */
        private final String urlTemplate;
        /** Host used when none is supplied. */
        private final String defaultIp;
        /** Port used when none is supplied. */
        private final String defaultPort;
    }
}
